{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Code for the **\"Prior effect\"** figure from the supplementary material."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colab only: uncomment the two commands below to fetch the repository code.\n",
    "# Also select a GPU: Runtime -> Change runtime type -> Hardware accelerator -> GPU.\n",
    "# !git clone https://github.com/DmitryUlyanov/deep-image-prior\n",
    "# !mv deep-image-prior/* ./"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import argparse\n",
    "import os\n",
    "# os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "\n",
    "import numpy as np\n",
    "from models import *\n",
    "\n",
    "import torch\n",
    "import torch.optim\n",
    "\n",
    "# `compare_psnr` no longer exists in `skimage.measure` in recent scikit-image releases;\n",
    "# the same metric is available as `skimage.metrics.peak_signal_noise_ratio`.\n",
    "try:\n",
    "    from skimage.metrics import peak_signal_noise_ratio as compare_psnr\n",
    "except ImportError:  # older scikit-image without `skimage.metrics`\n",
    "    from skimage.measure import compare_psnr\n",
    "from models.downsampler import Downsampler\n",
    "\n",
    "from utils.sr_utils import *\n",
    "\n",
    "torch.backends.cudnn.enabled = True\n",
    "torch.backends.cudnn.benchmark = True\n",
    "# Tensor type applied via `.type(dtype)` throughout the notebook;\n",
    "# fall back to the CPU type so the cells still run without a GPU.\n",
    "dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor\n",
    "\n",
    "imsize = -1  # forwarded to load_LR_HR_imgs_sr()\n",
    "factor = 4   # super-resolution factor, also given to the Downsampler\n",
    "enforse_div32 = 'CROP' # we usually need the dimensions to be divisible by a power of two\n",
    "\n",
    "PLOT = True  # set to False to skip the intermediate figures"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test image; the loader result is indexed by keys such as 'LR_np', 'HR_np' and 'orig_np' below.\n",
    "fname = 'data/sr/zebra_crop.png'\n",
    "imgs = load_LR_HR_imgs_sr(fname, imsize, factor, enforse_div32)\n",
    "\n",
    "if PLOT:\n",
    "    # Keep the three baselines returned by get_baselines() next to the other images.\n",
    "    baselines = get_baselines(imgs['LR_pil'], imgs['HR_pil'])\n",
    "    imgs['bicubic_np'], imgs['sharp_np'], imgs['nearest_np'] = baselines\n",
    "\n",
    "    plot_image_grid(\n",
    "        [imgs[key] for key in ('HR_np', 'bicubic_np', 'sharp_np', 'nearest_np')],\n",
    "        4, 12);\n",
    "\n",
    "    # PSNR of two of the baselines against the HR image.\n",
    "    psnr_bicubic = compare_psnr(imgs['HR_np'], imgs['bicubic_np'])\n",
    "    psnr_nearest = compare_psnr(imgs['HR_np'], imgs['nearest_np'])\n",
    "    print('PSNR bicubic: %.4f   PSNR nearest: %.4f' % (psnr_bicubic, psnr_nearest))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def closure():\n",
    "    \"\"\"Forward/backward pass for one optimization step, with logging.\n",
    "\n",
    "    Operates on notebook-level globals: reads `net`, `downsampler`, `mse`,\n",
    "    `img_LR_var`, `tv_weight`, `reg_noise_std`, `net_input_saved`, `noise`,\n",
    "    `imgs` and `PLOT`; appends to `psnr_history`; rebinds `i` and, when\n",
    "    `reg_noise_std > 0`, `net_input`.\n",
    "\n",
    "    Returns:\n",
    "        The total loss, on which `backward()` has already been called.\n",
    "    \"\"\"\n",
    "    global i, net_input\n",
    "\n",
    "    # Optional input perturbation: saved input plus re-sampled, scaled normal noise.\n",
    "    if reg_noise_std > 0:\n",
    "        net_input = net_input_saved + (noise.normal_() * reg_noise_std)\n",
    "\n",
    "    hr_pred = net(net_input)\n",
    "    lr_pred = downsampler(hr_pred)\n",
    "\n",
    "    # MSE on the downsampled output plus the weighted TV term on the HR output.\n",
    "    data_term = mse(lr_pred, img_LR_var)\n",
    "    total_loss = data_term + tv_weight * tv_loss(hr_pred)\n",
    "    total_loss.backward()\n",
    "\n",
    "    # Log\n",
    "    lr_pred_np = torch_to_np(lr_pred)\n",
    "    hr_pred_np = torch_to_np(hr_pred)\n",
    "    psnr_LR = compare_psnr(imgs['LR_np'], lr_pred_np)\n",
    "    psnr_HR = compare_psnr(imgs['HR_np'], hr_pred_np)\n",
    "    print('Iteration %05d    PSNR_LR %.3f   PSNR_HR %.3f' % (i, psnr_LR, psnr_HR), '\\r', end='')\n",
    "\n",
    "    # History\n",
    "    psnr_history.append([psnr_LR, psnr_HR])\n",
    "\n",
    "    # Every 500th call, show the HR image beside the clipped current estimate.\n",
    "    show_now = PLOT and i % 500 == 0\n",
    "    if show_now:\n",
    "        plot_image_grid([imgs['HR_np'], np.clip(hr_pred_np, 0, 1)], factor=8, nrow=2, interpolation='lanczos')\n",
    "\n",
    "    i += 1\n",
    "\n",
    "    return total_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment 1: no prior, optimize over pixels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# What is optimized and how the input is created\n",
    "OPT_OVER = 'input'        # given to get_params(): optimize `net_input` itself\n",
    "INPUT = 'noise'           # given to get_noise()\n",
    "input_depth = 3           # given to get_noise() and skip()\n",
    "\n",
    "# Model / downsampling options\n",
    "pad = 'reflection'        # padding mode given to skip()\n",
    "KERNEL_TYPE = 'lanczos2'  # downsampling kernel type (same value the Downsampler is built with)\n",
    "\n",
    "# Optimizer\n",
    "OPTIMIZER = 'adam'\n",
    "LR = 0.01\n",
    "num_iter = 2000\n",
    "\n",
    "# Regularizers read by closure(); both are switched off for this experiment\n",
    "tv_weight = 0.0           # multiplies the TV loss term\n",
    "reg_noise_std = 0.0       # scales the noise added to the input; no noise when 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identity mapping network, optimize over `net_input`\n",
    "net = nn.Sequential()\n",
    "\n",
    "# Spatial size handed to get_noise() is (size[1], size[0]) of the HR image.\n",
    "hr_size = imgs['HR_pil'].size\n",
    "net_input = get_noise(input_depth, INPUT, (hr_size[1], hr_size[0])).type(dtype).detach()\n",
    "\n",
    "# Use the configured kernel constant rather than repeating the literal.\n",
    "downsampler = Downsampler(n_planes=3, factor=factor, kernel_type=KERNEL_TYPE, phase=0.5, preserve_size=True).type(dtype)\n",
    "\n",
    "# Loss\n",
    "mse = torch.nn.MSELoss().type(dtype)\n",
    "\n",
    "# Low-resolution target that closure() compares the downsampled output against.\n",
    "img_LR_var = np_to_torch(imgs['LR_np']).type(dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Per-run state used by closure(): iteration counter and PSNR log.\n",
    "i = 0\n",
    "psnr_history = []\n",
    "\n",
    "# Copies that closure() combines into a perturbed input when reg_noise_std > 0.\n",
    "noise = net_input.detach().clone()\n",
    "net_input_saved = net_input.detach().clone()\n",
    "\n",
    "p = get_params(OPT_OVER, net, net_input)\n",
    "optimize(OPTIMIZER, p, closure, LR, num_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final estimate of this run, clipped to [0, 1].\n",
    "out_HR_np = np.clip(torch_to_np(net(net_input)), a_min=0, a_max=1)\n",
    "\n",
    "# Keep this run's PSNR log and its result; put_in_center() is given the\n",
    "# trailing dimensions of the original image.\n",
    "psnr_history_direct = psnr_history\n",
    "result_no_prior = put_in_center(out_HR_np, imgs['orig_np'].shape[1:])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment 2: using TV loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuses `net`, `OPT_OVER` and `reg_noise_std` as left by the previous cells;\n",
    "# only the TV weight changes and the input is re-initialized.\n",
    "tv_weight = 1e-7\n",
    "net_input = get_noise(input_depth, INPUT, (imgs['HR_pil'].size[1], imgs['HR_pil'].size[0])).type(dtype).detach()\n",
    "\n",
    "psnr_history = [] \n",
    "i = 0\n",
    "\n",
    "# Refresh the copies closure() reads when reg_noise_std > 0, as the other\n",
    "# experiments do, so they never refer to a previous experiment's input.\n",
    "net_input_saved = net_input.detach().clone()\n",
    "noise = net_input.detach().clone()\n",
    "\n",
    "p = get_params(OPT_OVER, net, net_input)\n",
    "optimize(OPTIMIZER, p, closure, LR, num_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final estimate of the TV-regularized run, clipped to [0, 1].\n",
    "out_HR_np = np.clip(torch_to_np(net(net_input)), a_min=0, a_max=1)\n",
    "\n",
    "# Keep this run's PSNR log and its result; put_in_center() is given the\n",
    "# trailing dimensions of the original image.\n",
    "psnr_history_tv = psnr_history\n",
    "result_tv_prior = put_in_center(out_HR_np, imgs['orig_np'].shape[1:])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment 3: using deep prior"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Same setting as before, but the image is now parametrized by a network: the network weights are optimized instead of the input pixels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimize the network weights now; TV term off, input noise on.\n",
    "OPT_OVER = 'net'\n",
    "tv_weight = 0.0\n",
    "reg_noise_std = 1./30. # This parameter probably should be set to a lower value for this example\n",
    "\n",
    "net = skip(\n",
    "    input_depth, 3,\n",
    "    num_channels_down=[128] * 5,\n",
    "    num_channels_up=[128] * 5,\n",
    "    num_channels_skip=[4] * 5,\n",
    "    upsample_mode='bilinear',\n",
    "    need_sigmoid=True,\n",
    "    need_bias=True,\n",
    "    pad=pad,\n",
    "    act_fun='LeakyReLU',\n",
    ").type(dtype)\n",
    "\n",
    "hr_size = imgs['HR_pil'].size\n",
    "net_input = get_noise(input_depth, INPUT, (hr_size[1], hr_size[0])).type(dtype).detach()\n",
    "\n",
    "# Compute number of parameters\n",
    "s = sum(np.prod(list(weight.size())) for weight in net.parameters())\n",
    "print('Number of params: %d' % s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Per-run state used by closure(): iteration counter and PSNR log.\n",
    "i = 0\n",
    "psnr_history = []\n",
    "\n",
    "# closure() rebuilds `net_input` from this saved copy plus re-sampled `noise`.\n",
    "noise = net_input.detach().clone()\n",
    "net_input_saved = net_input.detach().clone()\n",
    "\n",
    "p = get_params(OPT_OVER, net, net_input)\n",
    "optimize(OPTIMIZER, p, closure, LR, num_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference only: skip autograd bookkeeping for the final forward pass.\n",
    "# NOTE(review): closure() rebinds the global `net_input` to the noise-perturbed\n",
    "# tensor while reg_noise_std > 0, so this evaluates the last perturbed input,\n",
    "# not `net_input_saved` - confirm that is intended.\n",
    "with torch.no_grad():\n",
    "    out_HR_np = np.clip(torch_to_np(net(net_input)), 0, 1)\n",
    "\n",
    "# Keep this run's result and PSNR log; put_in_center() is given the\n",
    "# trailing dimensions of the original image.\n",
    "result_deep_prior = put_in_center(out_HR_np, imgs['orig_np'].shape[1:])\n",
    "psnr_history_deep_prior = psnr_history"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# HR image followed by the results of the three experiments.\n",
    "plot_image_grid(\n",
    "    [imgs['HR_np'], result_no_prior, result_tv_prior, result_deep_prior],\n",
    "    nrow=2, factor=8);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
